{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2844f430",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39fca96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel;\n",
    "# the extras spec is quoted so the shell does not try to glob-expand the brackets\n",
    "%pip install \"transformers[sentencepiece]\" --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4a6b08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Repository root: three directory levels above the kernel's current working directory\n",
    "ROOT_DIR = os.path.abspath(\"../../../\")\n",
    "# Put the root on sys.path so the `examples.pytorch.bart...` imports in the next cell resolve\n",
    "sys.path.append(ROOT_DIR)\n",
    "# Shared library handed to the FT encoder/decoding wrappers further down\n",
    "# NOTE(review): presumably produced by the project build -- confirm it exists before running\n",
    "lib_path = os.path.join(ROOT_DIR, './build/lib/libth_transformer.so')\n",
    "\n",
    "# disable warning in notebook (set here, before `transformers` is imported in the next cell)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f1a9a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import configparser\n",
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "from transformers import PreTrainedTokenizerFast\n",
    "from transformers import BartForConditionalGeneration, BartTokenizer \n",
    "from transformers import MBartForConditionalGeneration, MBartTokenizer \n",
    "from examples.pytorch.bart.utils.ft_encoder import FTBartEncoderWeight, FTBartEncoder\n",
    "from examples.pytorch.bart.utils.ft_decoding import FTBartDecodingWeight, FTBartDecoding, FTBart"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b057bdb2",
   "metadata": {},
   "source": [
    "## Setup HuggingFace BART/MBART Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c7009c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify model name or checkpoint path\n",
    "# model_name = 'facebook/bart-base' # BART\n",
    "model_name = 'facebook/mbart-large-50' # mBART\n",
    "\n",
    "# Choose the HF model/tokenizer classes and the layernorm layout from the checkpoint name\n",
    "if 'mbart' in model_name:\n",
    "    model_cls, tokenizer_cls = MBartForConditionalGeneration, MBartTokenizer\n",
    "    layernorm_type = \"pre_layernorm\"\n",
    "else:\n",
    "    model_cls, tokenizer_cls = BartForConditionalGeneration, BartTokenizer\n",
    "    layernorm_type = \"post_layernorm\"\n",
    "model = model_cls.from_pretrained(model_name)\n",
    "tokenizer = tokenizer_cls.from_pretrained(model_name)\n",
    "\n",
    "# The `mbart` flag given to the FT wrappers is read from the HF config's add_final_layer_norm\n",
    "is_mbart = model.config.add_final_layer_norm\n",
    "# Inference only: eval mode, weights moved to the GPU\n",
    "model = model.eval().to('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9e42a7c",
   "metadata": {},
   "source": [
    "## Setup FT BART Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20b421dc",
   "metadata": {},
   "source": [
    "### FT parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0777020",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = model.config\n",
    "\n",
    "# single-gpu so set TP=1, PP=1\n",
    "tensor_para_size = pipeline_para_size = 1\n",
    "\n",
    "# Model-structure options passed to the FT weight / encoder / decoding wrappers\n",
    "activation_type = config.activation_function\n",
    "bart_with_bias, use_gated_activation = True, False\n",
    "position_embedding_type = 1 # absolute positional embedding\n",
    "encoder_head_size = config.d_model // config.encoder_attention_heads\n",
    "decoder_head_size = config.d_model // config.decoder_attention_heads\n",
    "\n",
    "# Precision / runtime options: weights are loaded as float32, and use_fp16 decides\n",
    "# whether later cells switch the FT weights and the HF model to half precision\n",
    "weight_data_type = np.float32\n",
    "remove_padding = False\n",
    "use_fp16 = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40b63dee",
   "metadata": {},
   "source": [
    "### Load layer weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00f997c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword options given identically to the encoder and decoding weight containers\n",
    "ft_weight_kwargs = dict(\n",
    "    bart_with_bias=bart_with_bias,\n",
    "    mbart=is_mbart,\n",
    "    use_gated_activation=use_gated_activation,\n",
    "    position_embedding_type=position_embedding_type,\n",
    "    weight_data_type=weight_data_type,\n",
    ")\n",
    "\n",
    "# Encoder weights, read from the HF model after casting it to float32\n",
    "ft_encoder_weight = FTBartEncoderWeight(config, tensor_para_size, pipeline_para_size, **ft_weight_kwargs)\n",
    "ft_encoder_weight.load_from_model(model.float())\n",
    "\n",
    "# Decoding weights, loaded the same way\n",
    "ft_decoding_weight = FTBartDecodingWeight(config, tensor_para_size, pipeline_para_size, **ft_weight_kwargs)\n",
    "ft_decoding_weight.load_from_model(model.float())\n",
    "\n",
    "# FP16 run: call to_half() on both weight containers\n",
    "if use_fp16:\n",
    "    for ft_weight in (ft_encoder_weight, ft_decoding_weight):\n",
    "        ft_weight.to_half()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e89e563",
   "metadata": {},
   "source": [
    "### Setup Encoder, Decoder, and E2E model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb887e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword options passed identically to the FT encoder and the FT decoding module\n",
    "ft_shared_kwargs = dict(\n",
    "    tensor_para_size=tensor_para_size,\n",
    "    pipeline_para_size=pipeline_para_size,\n",
    "    bart_with_bias=bart_with_bias,\n",
    "    mbart=is_mbart,\n",
    "    position_embedding_type=position_embedding_type,\n",
    "    activation_type=activation_type,\n",
    "    layernorm_type=layernorm_type,\n",
    ")\n",
    "\n",
    "ft_encoder = FTBartEncoder(\n",
    "    ft_encoder_weight.w,\n",
    "    lib_path,\n",
    "    config.encoder_attention_heads,\n",
    "    encoder_head_size,\n",
    "    config.encoder_ffn_dim,\n",
    "    config.d_model,\n",
    "    remove_padding,\n",
    "    config.encoder_layers,\n",
    "    **ft_shared_kwargs,\n",
    ")\n",
    "\n",
    "ft_decoding = FTBartDecoding(\n",
    "    ft_decoding_weight.w,\n",
    "    lib_path,\n",
    "    config.decoder_attention_heads,\n",
    "    decoder_head_size,\n",
    "    config.decoder_ffn_dim,\n",
    "    config.d_model,\n",
    "    config.d_model,\n",
    "    config.decoder_layers,\n",
    "    config.decoder_start_token_id,\n",
    "    config.eos_token_id,\n",
    "    config.vocab_size,\n",
    "    **ft_shared_kwargs,\n",
    ")\n",
    "\n",
    "# FTBart wraps both modules into the single callable used for inference below\n",
    "ft_bart = FTBart(ft_encoder, ft_decoding)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fa9d8f7",
   "metadata": {},
   "source": [
    "## Example input and Inference parameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963c881b",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "input_len = 512\n",
    "\n",
    "# Seed the generator so the random token ids are the same on every run of the notebook\n",
    "torch.manual_seed(0)\n",
    "inputs = {\n",
    "    # random token ids drawn from the model vocabulary\n",
    "    'input_ids': torch.randint(0, config.vocab_size, size=(batch_size, input_len)).to(\"cuda\"),\n",
    "    # all ones: every position of the random batch is attended to (no padding)\n",
    "    'attention_mask': torch.ones(size=(batch_size, input_len)).to(\"cuda\")\n",
    "}\n",
    "\n",
    "# or use tokenized text as input\n",
    "# text = [\n",
    "#     \"FasterTransformer is a library implementing an accelerated engine for the inference of transformer-based neural networks, with a special emphasis on large models, spanning many GPUs and nodes in a distributed manner.\"\n",
    "# ]\n",
    "# batch_size = len(text)\n",
    "# inputs = tokenizer(text, padding=True, return_tensors=\"pt\").to(\"cuda\")  # must live on the same device as the model\n",
    "# input_len = inputs['input_ids'].size(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fcd77c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output length budget: HF gets max_output_len, FT gets two tokens fewer\n",
    "max_output_len = 32\n",
    "ft_max_output_len = max_output_len - 2  # to achieve identical results w/ HF, exclude start & end tokens\n",
    "\n",
    "# Beam search settings; top-k / top-p are left unset\n",
    "num_beams, beam_search_diversity_rate = 2, 0.0\n",
    "topk = topp = None\n",
    "\n",
    "# Number of timed runs per backend in the latency loops\n",
    "measurement_iters = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3bd58c0",
   "metadata": {},
   "source": [
    "## HF output and timing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff08e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the HF model in the precision under test (half()/float() cast in place and return the model)\n",
    "model = model.half() if use_fp16 else model.float()\n",
    "\n",
    "# Reference beam-search generation with HuggingFace, decoded to text\n",
    "hf_outputs = model.generate(\n",
    "    inputs['input_ids'],\n",
    "    max_length=max_output_len,\n",
    "    num_beams=num_beams,\n",
    ")\n",
    "hf_tokens = tokenizer.batch_decode(hf_outputs, skip_special_tokens=True)\n",
    "# print(\"HF output ids\",hf_outputs)\n",
    "# print(\"HF output text\",hf_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7412f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time `measurement_iters` HF generate() calls and report the median and p99 latency.\n",
    "# CUDA kernels are launched asynchronously, so the GPU is synchronized on both sides of\n",
    "# the timed region; perf_counter() is the monotonic clock meant for measuring intervals.\n",
    "hf_latencies = []\n",
    "for _ in range(measurement_iters):\n",
    "    torch.cuda.synchronize()  # drain earlier GPU work so it is not billed to this iteration\n",
    "    start_time = time.perf_counter()\n",
    "    model.generate(inputs['input_ids'], max_length=max_output_len, num_beams=num_beams, use_cache=True)\n",
    "    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock\n",
    "    end_time = time.perf_counter()\n",
    "    hf_latencies.append(end_time - start_time)\n",
    "hf_p50 = np.percentile(hf_latencies, 50)\n",
    "hf_p99 = np.percentile(hf_latencies, 99)\n",
    "print(f\"HF p50: {hf_p50*1000:.2f} ms, p99: {hf_p99*1000:.2f} ms \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7639f0e7",
   "metadata": {},
   "source": [
    "## FT output and timing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d9747a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoding options, gathered once and unpacked into the FT call\n",
    "ft_generation_kwargs = dict(\n",
    "    inputs_embeds=None,\n",
    "    beam_size=num_beams,\n",
    "    max_seq_len=ft_max_output_len,\n",
    "    top_k=topk,\n",
    "    top_p=topp,\n",
    "    beam_search_diversity_rate=beam_search_diversity_rate,\n",
    "    is_return_output_log_probs=False,\n",
    "    is_return_cum_log_probs=False,\n",
    ")\n",
    "return_dict = ft_bart(inputs['input_ids'], inputs['attention_mask'], **ft_generation_kwargs)\n",
    "\n",
    "# ft_bart returns output_ids of shape [batch_size, beam_width, max_output_seq_len]\n",
    "# ft_bart returns sequence_length of shape [batch_size, beam_width]\n",
    "ft_sequence_length = return_dict['sequence_lengths']\n",
    "ft_output_ids = return_dict['output_ids']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8fecf41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every batch item keep only the top beam (index 0): skip its 1st token and\n",
    "# stop at the sequence length reported for that beam\n",
    "ft_outputs = [\n",
    "    list(ft_output_ids[row, 0, :][1:ft_sequence_length[row, 0]])\n",
    "    for row in range(batch_size)\n",
    "]\n",
    "ft_tokens = tokenizer.batch_decode(ft_outputs, skip_special_tokens=True)\n",
    "# print(\"FT output ids\", ft_outputs)\n",
    "# print(\"FT output text\", ft_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3662ca5a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Time `measurement_iters` FT inference calls and report the median and p99 latency.\n",
    "# CUDA kernels are launched asynchronously, so the GPU is synchronized on both sides of\n",
    "# the timed region; perf_counter() is the monotonic clock meant for measuring intervals.\n",
    "ft_latencies = []\n",
    "for _ in range(measurement_iters):\n",
    "    torch.cuda.synchronize()  # drain earlier GPU work so it is not billed to this iteration\n",
    "    start_time = time.perf_counter()\n",
    "    return_dict = ft_bart(inputs['input_ids'],\n",
    "                          inputs['attention_mask'],\n",
    "                          inputs_embeds=None,\n",
    "                          beam_size=num_beams,\n",
    "                          max_seq_len=ft_max_output_len,\n",
    "                          top_k=topk,\n",
    "                          top_p=topp,\n",
    "                          beam_search_diversity_rate=beam_search_diversity_rate,\n",
    "                          is_return_output_log_probs=False,\n",
    "                          is_return_cum_log_probs=False)\n",
    "    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock\n",
    "    end_time = time.perf_counter()\n",
    "    ft_latencies.append(end_time - start_time)\n",
    "ft_p50 = np.percentile(ft_latencies, 50)\n",
    "ft_p99 = np.percentile(ft_latencies, 99)\n",
    "print(f\"FT p50: {ft_p50*1000:.2f} ms, p99: {ft_p99*1000:.2f} ms \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb35c608",
   "metadata": {},
   "source": [
    "## Performance summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ea1e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side recap of the run configuration and the latencies measured above\n",
    "summary_lines = (\n",
    "    f\"Precision: {'FP16' if use_fp16 else 'FP32'}\",\n",
    "    f\"Input length: {input_len}, Output length: {max_output_len}\",\n",
    "    f\"HF p50: {hf_p50*1000:.2f} ms, p99: {hf_p99*1000:.2f} ms \",\n",
    "    f\"FT p50: {ft_p50*1000:.2f} ms, p99: {ft_p99*1000:.2f} ms \",\n",
    ")\n",
    "for summary_line in summary_lines:\n",
    "    print(summary_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc74bd3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
